library(pacman)
p_load(
  here, fst, data.table, ggplot2, magrittr, janitor, stringr, purrr, collapse, lubridate
)

# Loading data ----------------------------------------------------------------
# Marginal production model: one row per plant. orispl_code is copied into a
# character plant_id_eia, which is used as the key for all later joins.
model_dt = 
  fread(
    here("Data/electricity-generation/output/PPRegressors-oge.csv")
  )[,plant_id_eia := as.character(orispl_code)] |> 
  setkey(plant_id_eia)
# Creating a long version (so rows are coefs, does not have intercept).
# Every column matching 'excess_load_' (linear and *_sq quadratic terms) is
# melted into rows of (plant_id_eia, nerc_adj, square, coef):
#   - nerc_adj: region label left after stripping 'excess_load_' and '_sq'
#   - square:   TRUE for quadratic-term columns (names ending in '_sq')
# Zero coefficients are dropped (value != 0), so a plant whose load
# coefficients are all zero gets no rows here.
model_dt_long = 
  melt(
    model_dt, 
    id.vars = 'plant_id_eia',
    measure.vars = patterns('excess_load_')
  )[value != 0,.(
    plant_id_eia, 
    nerc_adj = str_remove_all(variable, 'excess_load_|_sq'),
    square = str_detect(variable, '_sq$'),
    coef = value
  )] |>
  merge( # Adding the estimated excess load where max/min happens
    # est_extremum_<nerc> columns are melted to one row per plant x region.
    # The left join (all.x = TRUE) keeps coefficient rows with no extremum
    # (est_extremum = NA). Linear and squared rows of the same region share
    # the same est_extremum.
    melt(
      model_dt,
      id.vars = 'plant_id_eia', 
      measure.vars = patterns('est_extremum')
    )[!is.na(value),.(
      plant_id_eia, 
      nerc_adj = str_remove(variable, 'est_extremum_'),
      est_extremum = value
    )],
    by = c('plant_id_eia','nerc_adj'),
    all.x = TRUE
  )
# Some name cleaning, done in place on model_dt (setnames works by reference):
#   1. quadratic-term columns "excess_load_<x>_sq" -> "excess_load_sq_<x>"
#   2. every column containing "excess_load" gets a "coef_" prefix
# Wrapped in local() so the helper vectors do not leak into the global env.
local({
  sq_cols = grep('sq', colnames(model_dt), value = TRUE)
  sq_stub = sub('excess_load_', '', sub('_sq', '', sq_cols))
  setnames(model_dt, old = sq_cols, new = paste0('excess_load_sq_', sq_stub))

  load_cols = grep('excess_load', colnames(model_dt), value = TRUE)
  setnames(model_dt, old = load_cols, new = paste0('coef_', load_cols))
})
# Load data: hourly excess load by NERC region. Region labels are
# lower-cased, utc_time is renamed to datetime_utc, and the squared load is
# precomputed for the quadratic model terms.
nerc_load_dt =
  fread(here("Data/electricity-generation/output/RegLoad-oge.csv")) %>%
  .[, .(
    nerc_adj = tolower(nerc_adj),
    datetime_utc = utc_time,
    excess_load,
    excess_load_sq = excess_load^2
  )]
# Emissions rates: reshaped wide to one row per
# (plant_id_eia, net_generation_mwh_split). There is one emissions_rate_tons
# column per output x pollutant combination, named "<output>_<pollutant>"
# (dcast's default sep is "_").
# NOTE(review): plant_dt below expects names like emissions_rate_low_co2e,
# which presumes `output` takes values such as 'emissions_rate_low' /
# 'emissions_rate_high' — confirm against the source file.
emissions_rate_dt = 
  read.fst(
    path = here('Data/electricity-generation/output/emissions-md-dt.fst'),
    as.data.table = TRUE
  )[,plant_id_eia := as.character(plant_id_eia)] |>
  dcast(
    plant_id_eia + net_generation_mwh_split ~ output + pollutant,
    value.var = c('emissions_rate_tons')
  )
# Marginal damages per plant. plant_id_eia is converted to character to match
# the other tables. Presumably this provides md_per_mwh_low / md_per_mwh_high
# (selected in plant_dt below) — verify.
tot_md_dt = fread(
  here('Data/electricity-generation/output/damage-per-mwh-oge.csv')
)[,plant_id_eia := as.character(plant_id_eia)] 
# Plant info table: one row per plant with capacity, model parameters,
# emission rates and marginal damages.
# Each info file is reduced to the earliest-year record per plant, and the
# two files are stacked. fill = TRUE pads columns that exist in only one file.
# NOTE(review): a plant present in both info files would appear twice here
# and in every downstream join — confirm the files are disjoint.
plant_dt =
  rbind(
    read.fst(
      path = here("Data/electricity-generation/plant-info-dt-oge.fst"),
      as.data.table = TRUE
    )[, 
      .SD[which.min(year)], 
      by = .(plant_id_eia = as.character(plant_id_eia))
    ],
    read.fst(
      path = here("Data/electricity-generation/shaped-info-dt-oge.fst"),
      as.data.table = TRUE
    )[, 
      .SD[which.min(year)], 
      by = .(plant_id_eia = as.character(plant_id_eia))
    ],
    fill = TRUE
  ) |>
  # Inner joins: plants missing from any of these tables are dropped
  merge(model_dt, by = 'plant_id_eia')|>
  merge(emissions_rate_dt, by = 'plant_id_eia') |>
  merge(tot_md_dt, by = 'plant_id_eia') %>%
  # net_generation_mwh_split is present on both sides of the last merge, so
  # merge suffixes it; the left-hand (.x) copy is kept — presumably the one
  # from emissions_rate_dt. Verify the two copies agree.
  .[,.(
    plant_id_eia, 
    namepcap, 
    log_scale, 
    intercept,
    net_generation_mwh_split = net_generation_mwh_split.x, 
    md_per_mwh_high, 
    md_per_mwh_low,
    emissions_rate_low_co2e,
    emissions_rate_high_co2e,
    emissions_rate_low_nox,
    emissions_rate_high_nox,
    emissions_rate_low_so2,
    emissions_rate_high_so2,
    emissions_rate_low_pm25,
    emissions_rate_high_pm25
  )] |>
  setkey(plant_id_eia)

# Function to create table of epsilons for each plant.
#
# Draws `n_each` normal errors per plant with mean 0 and standard deviation
# exp(log_scale), i.e. log_scale is treated as the log of the error sd.
#
# Args:
#   plant_dt: table with (at least) plant_id_eia and log_scale, one row per
#     plant to simulate.
#   col_name: unused; kept so existing calls that pass it keep working.
#   n_each:   number of draws per plant (default 1000).
# Returns: data.table keyed by plant_id_eia with columns plant_id_eia
#   (character), epsilon_id (1..n_each) and epsilon. When there are no rows
#   or n_each <= 0, the table is empty with the same columns.
generate_epsilons = function(plant_dt, col_name, n_each = 1000){
  if(nrow(plant_dt) == 0 || n_each <= 0){
    empty_dt = data.table(
      plant_id_eia = character(),
      epsilon_id = integer(),
      epsilon = numeric()
    )
    return(setkey(empty_dt, plant_id_eia))
  }
  # Walking the rows pairwise (id, sd) avoids re-filtering plant_dt for every
  # id (O(n^2)). For unique ids the random draws come out in the same order
  # as before. rbindlist() replaces map_dfr(), which needs dplyr installed.
  epsilon_dt = 
    map2(
      as.character(plant_dt$plant_id_eia),
      exp(plant_dt$log_scale),
      \(id, eps_sd){
        data.table(
          plant_id_eia = id,
          epsilon_id = seq_len(n_each),
          epsilon = rnorm(n_each, mean = 0, sd = eps_sd)
        )
      }
    ) |>
    rbindlist()
  setkey(epsilon_dt, plant_id_eia)
  epsilon_dt
}

# Calculating fitted values of production for one plant.
#
# Steps:
#   1. Attach the plant's load coefficients to hourly excess load.
#   2. Sum the (extremum-capped) terms per hour.
#   3. Add the intercept and each epsilon draw.
#   4. Censor generation to [0, namepcap].
#   5. Convert generation into emissions and damages, using low/high rates on
#      either side of net_generation_mwh_split.
#   6. Average over the epsilon draws and write the result to disk.
#
# Args:
#   plant_id_eia_in: plant id (character) to process.
#   load_dt: long load table with nerc_adj, square, excess_load and time_id.
#     As built in actual_load_dt, squared rows carry the squared load in
#     excess_load.
#   plant_dt: plant-level table with plant_id_eia, namepcap, log_scale,
#     intercept, net_generation_mwh_split, emissions_rate_* and md_per_mwh_*.
#   model_dt_long: coefficient table with plant_id_eia, nerc_adj, square,
#     coef and est_extremum.
#   epsilon_dt: epsilon draws with plant_id_eia, epsilon_id and epsilon.
#
# Side effect: writes predicted-production-<plant_id_eia_in>.fst into the
#   predicted-production output folder (created if missing). Its columns are
#   plant_id_eia, time_id, fit_net_generation_mwh, fit_emissions_tons_co2e,
#   fit_emissions_tons_so2, fit_emissions_tons_nox, fit_emissions_tons_pm25
#   and damages.
# Returns: invisible(NULL) after writing. When the plant has no matching
#   coefficient rows (e.g. intercept-only models), nothing is written and an
#   empty table is returned.
calc_production_quants = function(
    plant_id_eia_in, 
    load_dt, 
    plant_dt, 
    model_dt_long,
    epsilon_dt
){
  message(plant_id_eia_in)
  # Attaching this plant's coefficients to every hour of load in the
  # matching regions / terms (linear and squared)
  prod_dt = 
    merge(
      load_dt,
      model_dt_long[plant_id_eia == plant_id_eia_in],
      by = c('nerc_adj','square'),
      allow.cartesian = TRUE
    )
  # Returning empty table for intercept only models.
  # NOTE(review): these columns do not match the files written below (no
  # tot_eload / md_* are ever computed); kept unchanged for compatibility.
  if(nrow(prod_dt) == 0){
    return(
      data.table(
        plant_id_eia = character(),
        time_id = numeric(),
        tot_eload = numeric(),
        fit_net_generation_mwh = numeric(),
        fit_emissions_tons_co2e = numeric(),
        fit_emissions_tons_so2 = numeric(),
        fit_emissions_tons_nox = numeric(),
        fit_emissions_tons_pm25 = numeric(),
        damages = numeric(),
        md_cal = numeric(),
        md_mro = numeric(),
        md_npcc = numeric(),
        md_rfc = numeric(),
        md_serc = numeric(),
        md_tre = numeric(),
        md_wecc = numeric()
      )
    )
  }
  # Summing coef * load over regions/terms per hour. Once load passes a
  # positive extremum, the term is held at coef * est_extremum.
  # NOTE(review): linear and squared rows share est_extremum, but squared rows
  # hold excess_load^2. For them this compares a squared load with an
  # unsquared extremum and caps at coef*est_extremum, not
  # coef*est_extremum^2 — confirm this is intended.
  # NA terms are dropped by fsum (na.rm = TRUE by default).
  coef_prod_dt = 
    prod_dt[,.(
      plant_id_eia, 
      time_id,
      fit_net_gen = fcase(
        est_extremum <= 0 | is.na(est_extremum), coef*excess_load, 
        est_extremum > 0 & excess_load <= est_extremum, coef*excess_load,
        est_extremum > 0 & excess_load > est_extremum, coef*est_extremum,
        default = NA
      )
    )] |>
    gby(plant_id_eia, time_id) |>
    fsum() |>
    setkey(plant_id_eia, time_id)
  # Attaching intercept, scale, capacity, rates, and every epsilon draw.
  # Only this plant's row of plant_dt is needed (same result as merging the
  # full table, without scanning every plant).
  gen_dt = 
    merge(
      plant_dt[plant_id_eia == plant_id_eia_in],
      coef_prod_dt,
      by = 'plant_id_eia'
    ) |>
    merge(
      epsilon_dt, 
      by = 'plant_id_eia',
      allow.cartesian = TRUE
    )
  # First adding intercept
  gen_dt[, 
    fit_net_generation_mwh_raw := intercept + 
      fifelse(is.na(fit_net_gen), 0, fit_net_gen)
  ]
  # Adding the draw and censoring to [0, namepcap]
  gen_dt[,
    fit_net_generation_mwh := fcase(
      fit_net_generation_mwh_raw + epsilon >= 0 
      & fit_net_generation_mwh_raw + epsilon <= namepcap, 
      fit_net_generation_mwh_raw + epsilon,
      fit_net_generation_mwh_raw + epsilon < 0, 0,
      fit_net_generation_mwh_raw + epsilon > namepcap, namepcap
    )
  ]
  # Removing columns we don't need anymore (saves memory on the large table)
  gen_dt[,c('namepcap','log_scale','intercept','epsilon_id','epsilon') := NULL]
  # Piecewise conversion of generation into emissions/damages: rate_low
  # applies to generation up to `split`, rate_high to generation above it
  split_rate = function(gen, split, rate_low, rate_high){
    fifelse(
      gen >= split,
      (gen - split)*rate_high + split*rate_low,
      gen*rate_low
    )
  }
  # Now calculating emissions and damages
  gen_dt[, `:=`(
    fit_emissions_tons_co2e = split_rate(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_co2e, emissions_rate_high_co2e
    ),
    fit_emissions_tons_so2 = split_rate(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_so2, emissions_rate_high_so2
    ),
    fit_emissions_tons_nox = split_rate(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_nox, emissions_rate_high_nox
    ),
    fit_emissions_tons_pm25 = split_rate(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_pm25, emissions_rate_high_pm25
    ),
    damages = split_rate(
      fit_net_generation_mwh, net_generation_mwh_split,
      md_per_mwh_low, md_per_mwh_high
    )
  )]
  # Taking the average over the epsilon draws
  plant_gen_dt = 
    gen_dt[,.(
      plant_id_eia, 
      time_id, 
      fit_net_generation_mwh, 
      fit_emissions_tons_co2e,
      fit_emissions_tons_so2,
      fit_emissions_tons_nox,
      fit_emissions_tons_pm25,
      damages
    )] |>
    gby(plant_id_eia, time_id) |>
    fmean()
  # Saving the results (creating the output folder on first use)
  out_dir = here('Data/electricity-generation/output/predicted-production')
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
  write.fst(
    x = plant_gen_dt, 
    path = file.path(
      out_dir,
      paste0('predicted-production-',plant_id_eia_in,'.fst')
    )
  )
  invisible(NULL)
}


# Simulates production for one plant. Draws n_eps_draws epsilons for the
# plant (using the global plant_dt), then hands them to
# calc_production_quants() along with the global plant_dt and model_dt_long.
# Returns calc_production_quants()'s value invisibly.
sim_eload_quants = function(plant_id_eia_in, load_dt, n_eps_draws = 1000){
  this_plant = plant_dt[plant_id_eia == plant_id_eia_in]
  draws = generate_epsilons(
    plant_dt = this_plant,
    col_name = 'time_id',
    n_each = n_eps_draws
  )
  invisible(
    calc_production_quants(
      plant_id_eia_in = plant_id_eia_in,
      load_dt = load_dt,
      plant_dt = plant_dt,
      model_dt_long = model_dt_long,
      epsilon_dt = draws
    )
  )
}

# Creating id's for every hour: one row per distinct datetime_utc in the load
# data, sorted ascending, with time_id = 1, 2, ... in that order.
# seq_len(.N) adapts to however many hours are present. A hard-coded count
# errors on any other number of hours.
datetime_dt =
  unique(nerc_load_dt[,.(datetime_utc)])[
    order(datetime_utc)  
  ][,
    time_id := seq_len(.N)
  ]

# Getting load for each time_id (2019 only). The linear and squared excess
# load are stacked into one long table flagged by `square`, so squared rows
# carry excess_load_sq in the excess_load column. The hourly time_id is then
# attached and the timestamp dropped.
actual_load_dt = local({
  load_2019 = nerc_load_dt[year(datetime_utc) == 2019]
  stacked = rbind(
    load_2019[, .(nerc_adj, datetime_utc, square = FALSE, excess_load)],
    load_2019[, .(nerc_adj, datetime_utc, square = TRUE, excess_load = excess_load_sq)]
  )
  merge(stacked, datetime_dt, by = 'datetime_utc')[, !'datetime_utc']
})
# Running it! One call per plant; each call writes its own .fst file, so the
# return values are not needed. walk() discards them instead of building and
# printing a long list. The seed makes the epsilon draws reproducible.
set.seed(218)
walk(
  plant_dt$plant_id_eia,
  \(id) sim_eload_quants(id, load_dt = actual_load_dt)
)
# combining and saving the results: reading every per-plant file written by
# calc_production_quants() and row-binding them by column name.
# The filename pattern keeps stray files in the folder from breaking
# read.fst. rbindlist() avoids map_dfr()'s dependency on dplyr.
# NOTE(review): files left over from earlier runs are also picked up.
pred_elec_gen_dt = 
  list.files(
    path = here('Data/electricity-generation/output/predicted-production'), 
    pattern = '^predicted-production-.*\\.fst$',
    full.names = TRUE
  ) |>
  lapply(read.fst, as.data.table = TRUE) |>
  rbindlist(use.names = TRUE)
# Final table: map time_id back to datetime_utc. This is an inner join, so
# rows whose time_id is not in datetime_dt are dropped. The output columns
# are selected and the table is keyed by (datetime_utc, plant_id_eia) for the
# merge with actual production.
pred_elec_gen_dt = 
  merge(
    pred_elec_gen_dt, 
    datetime_dt,
    by = 'time_id'
  )[,.(
    datetime_utc, 
    plant_id_eia, 
    fit_net_generation_mwh, 
    fit_emissions_tons_co2e,
    fit_emissions_tons_so2,
    fit_emissions_tons_nox,
    fit_emissions_tons_pm25,
    damages
  )] |>
  setkey(datetime_utc, plant_id_eia)
# Merging with actual production: 2019 hourly actuals per plant, with ids
# converted to character to match the predictions. The join is a full outer
# join, so plant-hours present on only one side are kept with NAs on the
# other.
# NOTE(review): actual emissions are *_mass_lb_* columns (pounds, by name)
# while predictions are fit_emissions_tons_* — convert before comparing.
elec_gen_dt = read.fst(
  path = here("Data/electricity-generation/plant-model-fit/elec-gen-fit-dt.fst"),
  as.data.table = TRUE
)
elec_gen_dt = elec_gen_dt[
  year(datetime_utc) == 2019,
  .(
    datetime_utc,
    plant_id_eia = as.character(plant_id_eia),
    net_generation_mwh,
    co2e_mass_lb_for_electricity,
    nox_mass_lb_for_electricity,
    so2_mass_lb_for_electricity
  )
]
setkey(elec_gen_dt, datetime_utc, plant_id_eia)
pred_elec_gen_dt = merge(
  x = pred_elec_gen_dt,
  y = elec_gen_dt,
  by = c('datetime_utc', 'plant_id_eia'),
  all = TRUE
)
# Saving the merged predicted + actual table to disk
write.fst(
  pred_elec_gen_dt,
  path = here('Data/electricity-generation/output/pred-elec-gen-dt.fst')
)
# Dropping the standalone actual-production table and reclaiming memory
rm(elec_gen_dt)
gc()
# Plots: Actual vs predicted (total production) 
# Choosing six random plants (fixed seed so the same plants are drawn)
set.seed(1234)
plant_ids = sample(unique(pred_elec_gen_dt$plant_id_eia), 6)
model_fit_p = 
  ggplot(
    pred_elec_gen_dt[plant_id_eia %in% plant_ids], 
    aes(
      x = net_generation_mwh, 
      y = fit_net_generation_mwh
    )
  ) + 
  geom_point(alpha = 0.05, color = 'black', shape = 19) +
  geom_smooth() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  facet_wrap(
    ~plant_id_eia,
    scales = "free"
  ) + 
  labs(
    x = "Actual Production (MWh)",
    y = "Predicted Production (MWh)"
  ) + 
  theme_minimal() 
model_fit_p
# Total production hours predicted vs actual: the share of hours with
# positive output per plant. Bubbles are sized (and the smoother weighted) by
# actual MWh.
# na.rm = TRUE because the full outer merge with actuals leaves NA hours
# (plant-hours present on only one side). Without it those plants get an NA
# share and silently drop out of the plot.
# NOTE(review): fit_net_generation_mwh is an average over epsilon draws, so it
# is 0 only when every draw is censored at 0.
plant_hours_p =
  pred_elec_gen_dt[,.(
    Actual = mean(net_generation_mwh > 0, na.rm = TRUE),
    Predicted = mean(fit_net_generation_mwh > 0, na.rm = TRUE),
    tot_mwh = sum(net_generation_mwh, na.rm = TRUE)
  ), by = .(plant_id_eia)] |>
  ggplot(aes(x = Actual, y = Predicted, size = tot_mwh, weight = tot_mwh)) + 
  geom_point(alpha = 0.3, color = 'gray30', shape = 19) +
  geom_smooth() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  scale_x_continuous(
    name = 'Actual Production Hours', 
    labels = scales::label_percent()
  ) + 
  scale_y_continuous(
    name = 'Predicted Production Hours', 
    labels = scales::label_percent()
  ) + 
  theme_minimal() +
  guides(size = 'none')
plant_hours_p


  # Intermittency: Number of startups/shut downs 
  p_load(collapse)
  intermittency_dt = 
    pred_elec_gen_dt[,.(
      datetime_utc, 
      plant_id_eia, 
      net_generation_mwh, 
      fit_net_generation_mwh
    )] |>
    gby(plant_id_eia) |>
    flag(t = datetime_utc, n = 0:1) %>%
    .[,.(
      startup_actual = fsum(net_generation_mwh > 0 & L1.net_generation_mwh == 0),
      startup_pred = fsum(fit_net_generation_mwh > 0 & L1.fit_net_generation_mwh == 0),
      shutdown_actual = fsum(net_generation_mwh == 0 & L1.net_generation_mwh > 0),
      shutdown_pred = fsum(fit_net_generation_mwh == 0 & L1.fit_net_generation_mwh > 0)), keyby = .(plant_id_eia)
    ]

ggplot(intermittency_dt) + 
    geom_point(aes(x = startup_actual, y = startup_pred)) + 
    geom_abline(slope = 1, intercept = 0) 
  #geom_density(aes(x = startup_actual)) + 
  #geom_density(aes(x = startup_pred), color = 'red') 

# Actual (black) vs predicted (red) hourly output for plant 55333 up to
# 2019-06-01. plant_id_eia is character, so the id is compared as a string.
# The cutoff is built as a UTC date-time: comparing a date-time column
# (assumed POSIXct here) with a plain Date compares seconds against days and
# selects nothing.
pred_elec_gen_dt[
  plant_id_eia == '55333' & datetime_utc <= ymd('2019-06-01', tz = 'UTC')
] |>
  ggplot(aes(x = datetime_utc)) +
  geom_point(aes(y = net_generation_mwh)) +
  geom_point(aes(y = fit_net_generation_mwh), color = 'red')


# 6 random plants with actual data: We should do this for many epsilons.
# elec_gen_dt has already been removed with rm() at this point. Plants are
# drawn from the merged table's rows that carry actual generation, which
# approximates the plants in the actual-production data.
set.seed(1234)
plant_ids = sample(
  unique(pred_elec_gen_dt[!is.na(net_generation_mwh), plant_id_eia]),
  size = 6
)
pred_elec_gen_dt[plant_id_eia %in% plant_ids] |>
  ggplot(aes(x = net_generation_mwh, y = fit_net_generation_mwh)) + 
  geom_point(alpha = 0.05) + 
  geom_smooth(se = FALSE) + 
  geom_abline(slope = 1, intercept = 0, linetype = 'dashed') + 
  facet_wrap(~plant_id_eia, scales = 'free') + 
  theme_minimal()
